#include "assembly.hpp"

#include <cassert>
#include <sstream>
#include <iostream>

#include "fasta_reader.hpp"
#include "graph.hpp"
#include "logger.hpp"
#include "simple_align.hpp"

// Default constructor: the overlap store member is initialised from this
// object's read store member.
Assembly::Assembly()
    : ol_store_(read_store_)
{
}

bool Assembly::ParseArgument(int argc, char* const argv[]) {
    return GetArgumentParser().ParseArgument(argc, argv);
}

// Runs the assembly pipeline in a fixed order:
//   1. print the arguments and configure the dumper level,
//   2. load the overlaps,
//   3. build the string graph, then the path graph,
//   4. save both graphs,
//   5. write id2name.txt,
//   6. only if a read file was supplied: load the reads and save the contigs.
void Assembly::Run() {
    PrintArguments();
    DUMPER.SetLevel(options_.dump);

    LOG(INFO)("Start");

    LOG(INFO)("Load Overlaps");
    LoadOverlaps(options_.overlap_file);

    LOG(INFO)("Create StringGraph");
    CreateStringGraph();

    LOG(INFO)("Create PathGraph");
    CreatePathGraph();

    LOG(INFO)("Save Graph");
    SaveGraph();

    LOG(INFO)("Start output fasta");
    read_store_.SaveIdToName(OutputPath("id2name.txt"));

    // Sequences can only be written when the reads themselves are available.
    const bool have_reads = !options_.read_file.empty();
    if (have_reads) {
        LOG(INFO)("Load fasta");
        LoadReads(options_.read_file);

        LOG(INFO)("Save Contigs");
        SaveContigs();
    }

    LOG(INFO)("End");
}

void Assembly::Usage() {
    std::cout << GetArgumentParser().Usage();
}

// Builds the command-line parser for fsa_assemble. Each named option is
// registered against the matching field of options_; the only positional
// argument is the overlap file.
ArgumentParser Assembly::GetArgumentParser() {
    ArgumentParser parser(
        "fsa_assemble",
        "Constructs contigs from filtered overlaps and corrected reads.",
        "1.0");

    // Length / identity thresholds.
    parser.AddNamedOption(options_.min_length,
                          "min_length",
                          "minimum length of reads");
    parser.AddNamedOption(options_.min_identity,
                          "min_identity",
                          "minimum identity of overlaps");
    parser.AddNamedOption(options_.min_aligned_length,
                          "min_aligned_length",
                          "minimum aligned length of overlaps");
    parser.AddNamedOption(options_.min_contig_length,
                          "min_contig_length",
                          "minimum length of contigs");

    // Input / output locations and formats.
    parser.AddNamedOption(options_.read_file,
                          "read_file",
                          "read filename");
    parser.AddNamedOption(options_.overlap_file_type,
                          "overlap_file_type",
                          "overlap file format. \"\" = filename extension, \"m4\" = M4 format, \"paf\" = PAF format generated by minimap2, \"ovl\" = OVL format generated by FALCON.",
                          "\"|m4|m4a|paf|ovl\"");
    parser.AddNamedOption(options_.output_directory,
                          "output_directory",
                          "directory for output files");

    // Graph behaviour and runtime settings.
    parser.AddNamedOption(options_.select_branch,
                          "select_branch",
                          "selecting method when encountering branches in the graph, \"no\" = do not select any branch, \"best\" = select the most probable branch",
                          "\"no|best\"");
    parser.AddNamedOption(options_.thread_size,
                          "thread_size",
                          "number of threads");

    // Testing / deprecated switches.
    parser.AddNamedOption(options_.dump,
                          "dump",
                          "for testing, dump intermediate files");
    parser.AddNamedOption(options_.run_mode,
                          "run_mode",
                          "for testing");
    parser.AddNamedOption(options_.lfc,
                          "lfc",
                          "deprecated, for testing");
    parser.AddNamedOption(options_.remove_chimer,
                          "remove_chimer",
                          "deprecated, remove chimer node");

    parser.AddPositionOption(options_.overlap_file,
                             "filterd_overlaps",
                             "input filename");

    return parser;
}

// Loads the overlaps from `fname` into the overlap store, using the format
// selected by options_.overlap_file_type, and reports a FATAL log entry when
// the store is still empty afterwards.
void Assembly::LoadOverlaps(const std::string &fname) {
    LOG(INFO)("Start Load Overlaps");
    ol_store_.Load(fname, options_.overlap_file_type);
    LOG(INFO)("End Load Overlaps: size = %d", ol_store_.Size());

    const bool nothing_loaded = (ol_store_.Size() == 0);
    if (nothing_loaded) {
        LOG(FATAL)("No overlaps were loaded");
    }
}

// Loads the reads from `fname` into the read store. The second argument is
// passed as an empty string and the third is options_.run_mode.
void Assembly::LoadReads(const std::string &fname) {
    read_store_.Load(
        fname,
        "",
        options_.run_mode);
}

// Builds the string graph from the loaded overlaps and reduces it:
// add overlaps -> mark transitive edges -> (optionally) mark chimer edges ->
// mark spurs -> resolve repeats or mark best overlaps -> mark spurs again ->
// identify simple paths.
void Assembly::CreateStringGraph() {
    LOG(INFO)("AddOverlaps");
    string_graph_.AddOverlaps(ol_store_.Get(),
                              options_.min_length,
                              options_.min_aligned_length,
                              options_.min_identity);

    LOG(INFO)("MarkTransitiveEdges");
    string_graph_.MarkTransitiveEdges();

    // The log line is emitted unconditionally; the chimer pass itself only
    // runs when the (deprecated) remove_chimer option is set.
    LOG(INFO)("MarkChimerEdges");
    if (options_.remove_chimer) {
        string_graph_.MarkChimerEdges();
    }
    string_graph_.MarkSpurEdges();

    // Either pass is logged as "MarkBestOverlap"; lfc selects which one runs.
    LOG(INFO)("MarkBestOverlap");
    if (!options_.lfc) {
        string_graph_.MarkBestOverlap();
    } else {
        string_graph_.ResolveRepeatEdges();
    }

    // Second spur pass after the step above, then path identification.
    string_graph_.MarkSpurEdges();
    string_graph_.IdentifySimplePaths();
}

void Assembly::CreatePathGraph() {

    DUMPER.SetDirectory(options_.output_directory);
    for (auto &path : string_graph_.GetPaths()) {
        path_graph_.AddEdge(path);
    }


    if (options_.dump > 0) 
        path_graph_.Dump(options_.output_directory + "/path_graph_0.txt");

    path_graph_.IdentifyPathSpur();

    if (options_.dump > 0)
        path_graph_.Dump(options_.output_directory + "/path_graph_1.txt");

    path_graph_.RemoveDuplicateSimplePath();
    if (options_.dump > 0)
        path_graph_.Dump(options_.output_directory + "/path_graph_2.txt");

    //path_graph_.RemoveCrossEdges();
    path_graph_.ConstructCompoundPaths((size_t)options_.thread_size);
    if (options_.dump > 0)
        path_graph_.Dump(options_.output_directory + "/path_graph_3.txt");

    path_graph_.MarkRepeatBridge();

    path_graph_.IdentifyPaths(options_.select_branch);
}

void Assembly::SaveContigs() {

    FILE *fcontig_seqs = fopen(OutputPath("contigs.fasta").c_str(), "w");
    FILE *fcontig_tiles = fopen(OutputPath("contig_tiles").c_str(), "w");

    FILE *fbubble_seqs = fopen(OutputPath("bubbles.fasta").c_str(), "w");
    FILE *fbubble_tiles = fopen(OutputPath("bubble_tiles").c_str(), "w");


    int ctgid = 0;
    for (const auto &path : path_graph_.GetPaths()) {
        std::list<StringEdge*> pcontig;
        std::list<std::pair<CompoundPathEdge*, std::list<std::list<StringEdge*>>>> acontigs;
        
         
        for (const auto &p : path) {
            if (p->type_ == "simple") {
                SimplePathEdge *e = static_cast<SimplePathEdge*>(p);
                pcontig.insert(pcontig.end(), e->path_.begin(), e->path_.end());
            }
            else if (p->type_ == "compound") {
                CompoundPathEdge *e = static_cast<CompoundPathEdge*>(p);
                assert(e->simple_paths_.size() > 0);
                StringNode *in_node = string_graph_.GetNode(e->in_node_->id_);
                StringNode *out_node = string_graph_.GetNode(e->out_node_->id_);
                assert(in_node != nullptr && out_node != nullptr);
                std::unordered_set<StringEdge*> doable;
                for (auto i : e->simple_paths_) {
                    assert(i->type_ != "compound");
                    SimplePathEdge *s = static_cast<SimplePathEdge*>(i);

                    for (auto ss : s->path_) {
                        doable.insert(ss);
                    }
                }

                std::vector<StringEdge*> &&shortest = string_graph_.ShortestPath(in_node, out_node, doable, [](StringEdge* e) {return e->score_; });
                assert(shortest.size() > 0);
                pcontig.insert(pcontig.end(), shortest.begin(), shortest.end());

                std::list<std::list<StringEdge*>> actg;
                while (shortest.size() > 0) {
                    actg.push_back(std::list<StringEdge*>(shortest.begin(), shortest.end()));
                    for (auto s : shortest) doable.erase(s);
                    shortest = string_graph_.ShortestPath(in_node, out_node, doable);
                }
                acontigs.push_back(std::make_pair(e, std::move(actg)));
            }
            else {
                assert(!"never come here");
            }
        }
        if (ctgid % 2 == 0) {
            SaveContigs(fcontig_seqs, fcontig_tiles, ctgid, pcontig);
            SaveBubbles(fbubble_seqs, fbubble_tiles, ctgid, acontigs);
        }

        ctgid++;
    }
}

// Writes one contig.
//   fseq   - FASTA output: the main sequence (seqs[0]) and any extra "stub"
//            sequences returned by ConstructContig1(); each is written only
//            when it is at least options_.min_contig_length long.
//   ftile  - one line per string edge of the contig (always written).
//   id     - contig index; printed as id/2 followed by 'F' (even) or 'R' (odd).
//   contig - the string edges making up the contig, in order.
void Assembly::SaveContigs(FILE *fseq, FILE *ftile, int id, const std::list<StringEdge*> &contig) {
    assert(fseq != NULL && ftile != NULL);

    const int ctg_no = id / 2;
    const char strand = ((id % 2) == 0 ? 'F' : 'R');

    std::vector<std::string> seqs = ConstructContig1(contig);
    assert(seqs.size() >= 1);

    // Main sequence: skipped when empty or shorter than the threshold.
    const std::string &primary = seqs[0];
    if (!primary.empty() && static_cast<int>(primary.length()) >= options_.min_contig_length) {
        // A contig that ends on the node it started from is reported as circular.
        const char *shape =
            contig.front()->in_node_ != contig.back()->out_node_ ? "linear" : "circular";

        fprintf(fseq, ">%06d%c %s length=%zd\n", ctg_no, strand, shape, primary.size());
        fprintf(fseq, "%s\n", primary.c_str());
    }

    // Extra sequences are emitted as "stub" records suffixed with their index.
    for (size_t i = 1; i < seqs.size(); ++i) {
        const std::string &stub = seqs[i];
        if (static_cast<int>(stub.length()) < options_.min_contig_length) {
            continue;
        }
        fprintf(fseq, ">%06d%c_%zd  %s length=%zd\n", ctg_no, strand, i, "stub", stub.size());
        fprintf(fseq, "%s\n", stub.c_str());
    }

    // Tiling information: one line per string edge.
    for (auto *edge : contig) {
        const std::string from_node = StringGraph::NodeIdString(edge->in_node_->Id());
        const std::string to_node = StringGraph::NodeIdString(edge->out_node_->Id());
        const std::string read_name = StringGraph::ReadIdString(edge->read_);

        fprintf(ftile, "%06d%c edge=%s~%s read=%s start=%d end=%d aligned=%d identity=%.02f\n",
            ctg_no,
            strand,
            from_node.c_str(),
            to_node.c_str(),
            read_name.c_str(),
            edge->start_,
            edge->end_,
            edge->score_,
            edge->identity_);
    }
}

// Writes the alternative paths ("bubbles") of one contig.
//   fseq    - receives one FASTA-style header line per retained path.
//   ftile   - receives one line per string edge of each retained path.
//   ctgid   - contig index; printed as ctgid/2 followed by 'F' (even) or 'R' (odd).
//   bubbles - for each compound path edge, its list of alternative paths;
//             the first path of each list is treated as the primary one.
//
// For every bubble the primary path is always kept. Any other path is kept
// only when both its sequence and the primary sequence are at least 2000
// characters long and ComputeSequenceSimilarity() reports an identity (in
// percent) <= options_.max_bubble_identity or a coverage (in percent)
// < options_.max_bubble_coverage. A bubble produces output only when at
// least one non-primary path was kept. bubble_index counts every bubble,
// including those that produce no output.
void Assembly::SaveBubbles(FILE *fseq, FILE* ftile, int ctgid, const std::list<std::pair<CompoundPathEdge*, std::list<std::list<StringEdge*>>>> &bubbles) {
    const size_t kMinComparableLength = 2000;
    int bubble_index = 1;

    for (const auto &bubble : bubbles) {
        auto e = bubble.first;
        const auto &paths = bubble.second;

        std::vector<const std::list<StringEdge*>*> dpaths;
        std::vector<std::string> dseqs;
        std::vector<std::array<double, 2>> similars;

        for (const auto &path : paths) {
            // Identify the primary path by position (address), not by
            // comparing list contents: it is cheaper and cannot mistake a
            // later path with equal content for the primary one.
            const bool is_primary = (&path == &paths.front());

            if (is_primary) {
                dseqs.push_back(ConstructContigStraight(path));
                dpaths.push_back(&path);
                similars.push_back({1.0, 1.0});
                continue;
            }

            assert(dseqs.size() >= 1);
            std::string seq = ConstructContigStraight(path);

            // Sequences that are too short are not compared and not kept.
            if (dseqs.front().size() < kMinComparableLength || seq.size() < kMinComparableLength) {
                continue;
            }

            auto d = ComputeSequenceSimilarity(seq, dseqs.front()); // coverage, identity

            if (d[1]*100 <= options_.max_bubble_identity || d[0]*100 < options_.max_bubble_coverage ) {
                dseqs.push_back(std::move(seq));
                dpaths.push_back(&path);
                similars.push_back(d);
            }
        }

        if (dseqs.size() > 1) {
            assert(dseqs.size() == dpaths.size());
            for (size_t i=0; i<dseqs.size(); ++i) {
                for (auto p : *dpaths[i]) {

                    fprintf(ftile, "%06d%c-%03d-%02zd edge=%s~%s read=%d start=%d end=%d aligned_length=%d identity=%.02f\n",
                        ctgid / 2,
                        ((ctgid % 2) == 0 ? 'F' : 'R'),
                        bubble_index,
                        i,
                        StringGraph::NodeIdString(p->in_node_->Id()).c_str(),
                        StringGraph::NodeIdString(p->out_node_->Id()).c_str(),
                        p->out_node_->ReadId(),
                        p->start_,
                        p->end_,
                        p->score_,
                        p->identity_);
                }

                // Header only: no sequence line is written to fseq here.
                fprintf(fseq, ">%06d%c-%03d-%02zd start=%s end=%s length=%zd size=%zd identity=%.02f coverage=%.02f\n",
                    ctgid / 2,
                    ((ctgid % 2) == 0 ? 'F' : 'R'),
                    bubble_index,
                    i,
                    StringGraph::NodeIdString(e->in_node_->Id()).c_str(),
                    StringGraph::NodeIdString(e->out_node_->Id()).c_str(),
                    dseqs[i].size(),
                    dpaths[i]->size(),
                    similars[i][1],
                    similars[i][0]);
            }
        }

        bubble_index++;
    }
}

// Concatenates EdgeToSeq() of every edge in `contig`, in list order, with no
// special handling of the first or last node.
std::string Assembly::ConstructContigStraight(const std::list<StringEdge*> &contig) {
    std::string joined;

    for (auto it = contig.begin(); it != contig.end(); ++it) {
        joined += EdgeToSeq(*it);
    }

    return joined;
}

// Builds the sequence of a contig from its string edges.
//
// Head (in-node of the first edge), handled only when it has exactly one
// out-edge:
//   - no in-edges:        the whole read of the node is prepended
//                         (as-is when Id() < 0, reverse-complemented otherwise);
//   - exactly one in-edge: nothing is prepended (per the original note, this
//                         is the circular case and is covered at the tail);
//   - several in-edges:   the sequence of the in-edge with the smallest
//                         length_ is prepended.
// Body: every edge that is not the last edge pointer contributes EdgeToSeq().
// Tail: the last edge is appended only when its out-node has exactly one
// in-edge; otherwise it is left out (per the original note, it is a repeat
// area that is emitted as an independent contig).
std::string Assembly::ConstructContig(const std::list<StringEdge*> &contig) {
    std::string result;

    StringEdge *const tail_edge = contig.back();

    auto head = contig.front()->in_node_;
    if (head->OutDegree() == 1) {
        if (head->InDegree() == 0) {
            int read = head->ReadId();
            std::string readseq = read_store_.GetSeq(read);
            if (head->Id() < 0) {
                result += readseq;
            } else {
                result += Seq::ReverseComplement(readseq);
            }
        } else if (head->InDegree() > 1) {
            auto shortest_in = std::min_element(
                head->in_edges_.begin(), head->in_edges_.end(),
                [](const StringEdge *lhs, const StringEdge *rhs) {
                    return lhs->length_ < rhs->length_;
                });
            result += EdgeToSeq(*shortest_in);
        }
    }

    for (StringEdge *edge : contig) {
        if (edge == tail_edge) {
            continue;
        }
        result += EdgeToSeq(edge);
    }

    auto tail = tail_edge->out_node_;
    if (tail->InDegree() == 1) {
        result += EdgeToSeq(tail_edge);
    }

    return result;
}


// Variant of ConstructContig() that differs only in the head handling: when
// the head node has exactly one out-edge and its in-degree is either 0 or
// greater than 1, the whole read of the head node is prepended (as-is when
// Id() < 0, reverse-complemented otherwise). With exactly one in-edge
// nothing is prepended.
// Body: every edge that is not the last edge pointer contributes EdgeToSeq().
// Tail: the last edge is appended only when its out-node has exactly one
// in-edge.
std::string Assembly::ConstructContigMain(const std::list<StringEdge*> &contig) {
    std::string result;

    StringEdge *const tail_edge = contig.back();

    auto head = contig.front()->in_node_;
    if (head->OutDegree() == 1) {
        const bool prepend_head_read = head->InDegree() == 0 || head->InDegree() > 1;
        if (prepend_head_read) {
            int read = head->ReadId();
            std::string readseq = read_store_.GetSeq(read);
            if (head->Id() < 0) {
                result += readseq;
            } else {
                result += Seq::ReverseComplement(readseq);
            }
        }
    }

    for (StringEdge *edge : contig) {
        if (edge == tail_edge) {
            continue;
        }
        result += EdgeToSeq(edge);
    }

    auto tail = tail_edge->out_node_;
    if (tail->InDegree() == 1) {
        result += EdgeToSeq(tail_edge);
    }

    return result;
}


// Returns the sequences belonging to a contig: element 0 is always
// ConstructContig(contig); up to two extra sequences may follow.
//
// Extra head sequence: added when the head node has several out-edges, the
// contig's first edge is that node's out_edges_[0] (presumably so that only
// one of the branching contigs emits it -- TODO confirm) and the node has no
// in-edges. It is the node's whole read, as-is when Id() < 0 and
// reverse-complemented otherwise.
//
// Extra tail sequence: added when the tail node has several in-edges, the
// contig's last edge is that node's in_edges_[0], and the node's out-degree
// is 0 or greater than 1. It is EdgeToSeq() of the in-edge with the smallest
// length_.
std::vector<std::string> Assembly::ConstructContig1(const std::list<StringEdge*> &contig) {
    std::vector<std::string> result;
    result.push_back(ConstructContig(contig));

    StringEdge *const head_edge = contig.front();
    StringEdge *const tail_edge = contig.back();

    auto head = head_edge->in_node_;
    const bool owns_head = head->OutDegree() > 1 && head->out_edges_[0] == head_edge;
    if (owns_head && head->InDegree() == 0) {
        int read = head->ReadId();
        std::string readseq = read_store_.GetSeq(read);
        if (head->Id() < 0) {
            result.push_back(readseq);
        } else {
            result.push_back(Seq::ReverseComplement(readseq));
        }
    }

    auto tail = tail_edge->out_node_;
    const bool owns_tail = tail->InDegree() > 1 && tail->in_edges_[0] == tail_edge;
    if (owns_tail && (tail->OutDegree() == 0 || tail->OutDegree() > 1)) {
        auto shortest_in = std::min_element(
            tail->in_edges_.begin(), tail->in_edges_.end(),
            [](const StringEdge *lhs, const StringEdge *rhs) {
                return lhs->length_ < rhs->length_;
            });
        result.push_back(EdgeToSeq(*shortest_in));
    }

    return result;
}


// Returns the sequences belonging to a contig: element 0 is always
// ConstructContigMain(contig); up to two extra sequences may follow.
//
// Extra head sequence: added when the head node has several out-edges, the
// contig's first edge is that node's out_edges_[0], and the node has no
// in-edges. It is the whole read of the head node.
//
// Extra tail sequence: added when the tail node has several in-edges, the
// contig's last edge is that node's in_edges_[0], and the node's out-degree
// is 0 or greater than 1. It is the whole read of the tail node.
//
// A node's read is used as-is when the node's Id() is negative and
// reverse-complemented otherwise. The tail read is oriented by the tail
// node's own Id(); the previous code used the head node's Id() here, which
// could emit the tail read on the wrong strand.
std::vector<std::string> Assembly::ConstructContigAll(const std::list<StringEdge*> &contig) {
    std::vector<std::string> seqs;
    seqs.push_back(ConstructContigMain(contig));

    auto first = contig.front()->in_node_;
    // first->OutDegree() == 0: never happens
    // first->OutDegree() == 1: the read has already been added to the main sequence
    // first->OutDegree() >  1: the read becomes an independent contig
    if (first->OutDegree() > 1 && first->out_edges_[0] == contig.front()) {
        // first->InDegree() == 0: add the whole read
        // first->InDegree() == 1: the read has already been added
        // first->InDegree() >  1: converted when handling the tail
        if (first->InDegree() == 0) {
            int read = first->ReadId();
            std::string readseq = read_store_.GetSeq(read);
            seqs.push_back(first->Id() < 0 ? readseq : Seq::ReverseComplement(readseq));
        }
    }

    auto last = contig.back()->out_node_;
    // last->InDegree() == 0: never happens
    // last->InDegree() == 1: already added to the main sequence
    // last->InDegree() >  1: contig.back() is a repeat area and is emitted
    //                        as an independent contig
    if (last->InDegree() > 1 && last->in_edges_[0] == contig.back()) {
        // last->OutDegree() == 1: the in-edge has already been added
        // otherwise: emit the whole read of the tail node
        if (last->OutDegree() == 0 || last->OutDegree() > 1) {
            int read = last->ReadId();
            std::string readseq = read_store_.GetSeq(read);
            // Orientation must follow the node whose read is emitted.
            seqs.push_back(last->Id() < 0 ? readseq : Seq::ReverseComplement(readseq));
        }
    }

    return seqs;
}

// Returns the piece of sequence contributed by string edge `e`.
//
// The sequence is taken from the read of the edge's out-node. GetSeqArea()
// yields a tuple whose element 1 says whether the piece must be
// reverse-complemented and whose elements 2 and 3 are the two coordinates of
// the piece inside the read:
//   - forward: characters [elem2, elem3) of the read;
//   - reverse: characters [elem3, elem2) of the read, reverse-complemented.
//
// std::string::substr takes (position, count), so the reverse branch must
// pass the span length elem2 - elem3. The previous code passed elem2 itself,
// which returned extra characters whenever elem3 was not 0.
std::string Assembly::EdgeToSeq(const StringEdge *e) {

    int read = e->out_node_->ReadId();
    std::string readseq = read_store_.GetSeq(read);

    const std::tuple<int, bool, int, int> area = e->GetSeqArea();
    const bool reversed = std::get<1>(area);
    const int from = std::get<2>(area);
    const int to = std::get<3>(area);

    if (!reversed) {
        assert(to > from);
        return readseq.substr(from, to - from);
    } else {
        assert(from > to);
        return Seq::ReverseComplement(readseq.substr(to, from - to));
    }
}

void Assembly::SaveGraph() {
    const std::string &output_directory = options_.output_directory;

    string_graph_.SaveEdges((output_directory+"//graph_edges"));
    path_graph_.SaveEdges((output_directory + "//graph_paths"));
}

std::array<double,2> Assembly::ComputeSequenceSimilarity(const std::string &qseq, const std::string &tseq) {
    SimpleAlign sa(tseq, 11);
    SimpleAlign::Result r = sa.Align(qseq, 500, false);
    if (r.target_end > r.target_start) {
        return std::array<double, 2>{1.0*(r.target_end - r.target_start) / tseq.size(), 
                1 - r.distance*1.0 / ((r.query_end - r.query_start + r.target_end - r.target_start)/2) };

    } else {
        return std::array<double, 2>{0, 0};
    }
}



// Logs the current option values, as rendered by the argument parser, at
// INFO level.
void Assembly::PrintArguments() {
    const auto summary = GetArgumentParser().PrintOptions();
    LOG(INFO)("Arguments: \n%s", summary.c_str());
}
